import io
from datetime import datetime

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from scipy import stats


# Configure the Streamlit page: title and a short description of the app.
st.title('新冠病毒感染数据分析')
st.write('上传CSV文件并分析各城市每天新增新冠病毒感染人数的数据。')

# Matplotlib text settings, applied in a single update call.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],  # font used so Chinese labels display properly
    'axes.unicode_minus': False,    # used so the minus sign displays properly
})

def _parse_date_ids(values):
    """Return the ``dateId`` column converted to a datetime64 Series.

    Numeric input is interpreted as ``YYYYMMDD``. Values go through
    ``int -> str`` first so a float column (e.g. ``20200101.0``) is parsed
    the same way as an integer one. Already-parsed datetimes are returned
    unchanged; anything else is handed to ``pd.to_datetime`` to infer.

    Raises:
        ValueError, TypeError: propagated from pandas when a value cannot
            be interpreted as a date.
    """
    if pd.api.types.is_datetime64_any_dtype(values):
        return values
    if pd.api.types.is_numeric_dtype(values):
        as_text = values.astype('int64').astype(str)
        return pd.to_datetime(as_text, format='%Y%m%d')
    return pd.to_datetime(values)


def _plot_infection_trend(region_data, threshold):
    """Build the per-region trend figure and return it.

    The line shows the mean ``currentConfirmedCount`` per ``dateId``; red
    markers show the per-date mean of the rows whose count exceeds
    ``threshold``. The caller is responsible for closing the figure.
    """
    fig, ax = plt.subplots()

    # Draw through matplotlib directly so the line, the scatter markers and
    # the mdates formatter below all use matplotlib's own date units.
    daily_mean = region_data.groupby('dateId')['currentConfirmedCount'].mean()
    ax.plot(daily_mean.index, daily_mean.values, label='感染人数')

    # Highlight the dates whose rows exceed the threshold.
    high_risk_data = region_data[region_data['currentConfirmedCount'] > threshold]
    if not high_risk_data.empty:
        high_risk_grouped = high_risk_data.groupby('dateId')['currentConfirmedCount'].mean()
        ax.scatter(high_risk_grouped.index, high_risk_grouped.values,
                   color='red', label='高风险区域')

    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    fig.autofmt_xdate()  # rotate/align the date tick labels

    ax.set_title('感染人数趋势')
    ax.set_xlabel('日期')
    ax.set_ylabel('平均感染人数')
    ax.legend()
    return fig


def _figure_to_png(fig):
    """Render ``fig`` to PNG and return the encoded bytes.

    Rendering into memory avoids writing a shared temp file on disk.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    return buffer.getvalue()


# Upload the data file
uploaded_file = st.file_uploader("上传CSV文件", type="csv")
if uploaded_file is not None:
    # Read the data
    data = pd.read_csv(uploaded_file)

    # Make sure every column used below is present
    expected_columns = ['id', 'confirmedCount', 'confirmedIncr', 'curedCount', 'curedIncr', 'currentConfirmedCount',
                        'currentConfirmedIncr', 'dateId', 'deadCount', 'deadIncr', 'suspectedCount',
                        'suspectedCountIncr']
    # provinceName drives the province selector further down, so it is
    # validated up front instead of failing later with a KeyError.
    required_columns = expected_columns + ['provinceName']
    if not all(column in data.columns for column in required_columns):
        st.error('上传的数据文件格式不正确。')
    else:
        # Pre-processing: drop rows with missing values and negative counts
        data = data.dropna()
        data = data[data['confirmedCount'] >= 0]

        # Convert dateId to datetime; assign() returns a new frame, so the
        # filtered frame above is not modified in place.
        try:
            data = data.assign(dateId=_parse_date_ids(data['dateId']))
            dates_ok = True
        except (ValueError, TypeError):
            dates_ok = False

        if not dates_ok:
            st.error('无法解析 dateId 列中的日期，请使用 YYYYMMDD 或标准日期格式。')
        elif data.empty:
            # Without rows there is no date range to offer in the pickers.
            st.error('数据清洗后没有可用的记录。')
        else:
            # Sort after conversion so the order is chronological
            data = data.sort_values(by='dateId')

            # Overall trend: total confirmed count per calendar day
            st.subheader('总的新冠病毒感染趋势')
            fig1, ax1 = plt.subplots()
            total_trend = data.groupby(data['dateId'].dt.date)['confirmedCount'].sum()
            total_trend.plot(kind='line', ax=ax1)
            ax1.set_title('Total Infection Trend')
            ax1.set_xlabel('Date')
            ax1.set_ylabel('Total Confirmed Count')
            st.pyplot(fig1)
            plt.close(fig1)  # release the figure; the script reruns on every interaction

            # User interaction: date range selection
            min_date = data['dateId'].dt.date.min()
            max_date = data['dateId'].dt.date.max()
            st.write(f"时间范围{min_date}——{max_date}")

            start_date = st.date_input('选择开始日期', value=min_date)
            end_date = st.date_input('选择结束日期', value=max_date)

            # Validate the selected range against the data range
            if start_date < min_date or end_date > max_date or start_date > end_date:
                st.error('日期选择有误，请选择一个在数据范围内的开始日期和结束日期，且开始日期不应晚于结束日期。')
            else:
                row_dates = data['dateId'].dt.date
                filtered_data = data[(row_dates >= start_date) & (row_dates <= end_date)]

                # Province drop-down
                province_selection = st.selectbox('选择省份', filtered_data['provinceName'].unique())

                # Keep only the selected province
                province_data = filtered_data[filtered_data['provinceName'] == province_selection]

                # Threshold above which a date is highlighted
                threshold = st.slider('设置感染人数阈值', min_value=0, max_value=1000, value=100)

                # One chart per distinct provinceName left in province_data
                # (at most one, because of the equality filter above).
                city_ids = province_data['provinceName'].unique()
                for city_id in city_ids:
                    city_data = province_data[province_data['provinceName'] == city_id]
                    st.subheader(f'{city_id}感染人数趋势')

                    fig = _plot_infection_trend(city_data, threshold)
                    image_bytes = _figure_to_png(fig)
                    plt.close(fig)

                    st.image(io.BytesIO(image_bytes), caption='感染人数趋势图', use_column_width=True)
                    st.download_button(
                        label="下载图表",
                        data=image_bytes,
                        file_name=f'{city_id}_infection_trend.png',
                        mime='image/png',
                        key=f'download_{city_id}',
                    )